% created by Nils Hase 2016 (nilshase@math.uni-bremen.de)

% Sparse dictionary regularization with a dictionary configured for the US
% setup. For each candidate regularization parameter alpha the problem
% solved is
%
% c_opt = arg min_c 1/2 ||ADc-y||_2^2 + alpha ||c||_1
%    s.t. x = Dc >= 0
%
% The regularization parameter alpha_opt is chosen from the candidates by
% Morozov's discrepancy principle.

function [x_opt,alpha_opt,D,c_opt] = FISTA_DIC_POS(A,y,alpha,plotter)
% FISTA_DIC_POS  Nonnegative sparse dictionary regularization via projected FISTA.
%
%   [x_opt,alpha_opt,D,c_opt] = FISTA_DIC_POS(A,y,alpha,plotter) loads the
%   dictionary D from 'dictionary.mat' and land_grids from
%   'synthetic_methane_data.mat'. For each candidate alpha(k), in the given
%   order, it runs up to max_iter iterations of FISTA (gradient step,
%   soft thresholding, projection of x = D*c onto x >= 0 through the
%   pixel-basis coefficients). The sweep stops at the first alpha whose
%   RMS residual norm(A*D*c - y)/sqrt(length(y)) is below tau*delta
%   (Morozov's discrepancy principle); otherwise the last alpha is used.
%
%   Inputs:
%     A       - forward operator; A*D must be defined.
%     y       - data (presumably a column vector; confirm with callers).
%     alpha   - candidate regularization parameters. Because the first
%               alpha meeting the discrepancy bound is taken, alpha is
%               presumably sorted in decreasing order.
%     plotter - optional; if equal to 1 the discrepancy curve is plotted.
%               Defaults to 0 (no plot).
%
%   Outputs:
%     x_opt     - D*c_opt.
%     alpha_opt - selected regularization parameter ([] if alpha is empty).
%     D         - the loaded dictionary, converted to a sparse matrix.
%     c_opt     - dictionary coefficients for alpha_opt (zeros if alpha
%                 is empty).
%
%   Errors with id 'FISTA_DIC_POS:pixelBasis' if some row k of D
%   (k = 1..length(land_grids)) does not contain exactly one entry equal
%   to 1.

    if nargin < 4 || isempty(plotter)
        plotter = 0;
    end

    % Load dictionary for US setup and land grids
    S = load('dictionary.mat','D');
    D = S.D;
    S = load('synthetic_methane_data.mat','land_grids');
    land_grids = S.land_grids;

    %% FISTA with projection in each step

    % Find which columns of D are the pixel basis: for pixel k, the unique
    % column holding a 1 in row k. The projection step subtracts the
    % negative part of x through these coefficients, so it assumes
    % size(D,1) == length(land_grids).
    n_pix = length(land_grids);
    pixel_col = zeros(n_pix,1);
    for k = 1:n_pix
        cols = find(D(k,:) == 1);
        if numel(cols) ~= 1
            error('FISTA_DIC_POS:pixelBasis', ...
                'Expected exactly one entry equal to 1 in row %d of D, found %d.', ...
                k, numel(cols));
        end
        pixel_col(k) = cols;
    end

    % Make D a sparse matrix (if not yet done)
    D = sparse(D);

    % Noise level delta is hard-coded to 1. NOTE(review): presumably the
    % data are scaled so the per-sample noise level is 1 -- confirm.
    delta = 1;
    tau = 1;                        % safety factor of the discrepancy principle

    AD = A*D;                       % operator acting on dictionary coefficients
    n = size(D,2);
    c_0 = zeros(n,1);
    lambda = 1/norm(AD,'fro')^2;    % step size; ||AD||_F >= ||AD||_2 keeps it stable
    max_iter = 250;                 % 200-300 should be enough for this setup
    tol = 1e-6;

    discrepancy = zeros(size(alpha));
    c_opt = c_0;
    alpha_opt = [];
    n_done = 0;                     % number of alphas actually evaluated

    for k = 1:length(alpha)
        c = c_0;
        c_old = c_0;
        x_old = D*c;

        for kk = 1:max_iter
            c_old_old = c_old;
            c_old = c;

            % FISTA extrapolation, then gradient step on 1/2||ADc-y||^2
            cc = c_old + (kk-2)/(kk+1)*(c_old - c_old_old);
            c = cc - AD'*(lambda*(AD*cc - y));

            % Shrinkage step (soft thresholding for alpha*||c||_1)
            c = sign(c).*max(abs(c)-alpha(k)*lambda,0);

            % Projection step for nonnegative parameters: remove the
            % negative part of x = D*c via the pixel-basis coefficients.
            x = D*c;
            x_neg = x; x_neg(x>0) = 0;
            if norm(x_neg) > 0
                c_neg = zeros(size(c));
                c_neg(pixel_col) = x_neg;
                c = c - c_neg;
                % Keep x consistent with the projected coefficients so the
                % stopping test compares actual iterates.
                x = D*c;
            end

            if norm(x_old - x) < tol
                break;
            end
            x_old = x;
        end

        discrepancy(k) = norm(AD*c - y)/sqrt(length(y));
        c_opt = c;
        alpha_opt = alpha(k);
        n_done = k;

        % Morozov's discrepancy principle
        if discrepancy(k) < tau*delta
            break;
        end
    end

    if plotter == 1 && n_done > 0
        figure();
        % Plot only the alphas actually evaluated; later entries of
        % discrepancy were never computed.
        plot(log10(alpha(1:n_done)),discrepancy(1:n_done))
        hold on;
        plot([log10(alpha(1)),log10(alpha(end))],tau*delta*[1,1],'r--')
        title('Discrepancy curve FISTA DIC POS')
        hold off;
    end

    x_opt = D*c_opt;

end
